/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/format/range_axis_util.h"

namespace fe {
/* Dispatch table from an original tensor format to the converter that fills the axis-indexed
 * range vector for that format. Lookups of formats absent from this table fail:
 * GetRangeAxisValueByOriginFormat returns FAILED and HasAxisValueFunc returns false. */
const std::map<ge::Format, GetRangeAxisValueInfoByFormatPtr> RangeAxisUtil::get_range_axis_value_func_map = {
    {ge::FORMAT_NCHW, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByNCHW)},
    {ge::FORMAT_NHWC, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByNHWC)},
    {ge::FORMAT_NC1HWC0, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByNC1HWC0)},
    {ge::FORMAT_FRACTAL_Z, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByFz)},
    {ge::FORMAT_HWCN, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByHWCN)},
    {ge::FORMAT_CHWN, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByCHWN)},
    {ge::FORMAT_ND, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByND)},
    {ge::FORMAT_NDHWC, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByNDHWC)},
    {ge::FORMAT_NCDHW, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByNCDHW)},
    /* The last N of DHWCN is considered as Cout, which is the C of NDHWC. */
    {ge::FORMAT_DHWCN, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByDHWCN)},
    {ge::FORMAT_DHWNC, std::make_shared<GetRangeAxisValueInfoByFormat>(GetRangeAxisValueByDHWNC)}};

/* Validates the inputs shared by the per-format range converters. range_value is not modified.
 * Returns FAILED (after logging) when:
 *   - range_value has fewer than AXIS_BOTTOM slots, so axis indices cannot be written safely;
 *   - original_dim_vec is empty, or has fewer than min_size dimensions;
 *   - original_dim_vec and original_range_vec differ in length;
 *   - c0 is zero (callers pass it to DivisionCeiling as the divisor).
 * Returns SUCCESS otherwise. */
Status RangeAxisUtil::CheckParamValue(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                      const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                      vector<std::pair<int64_t, int64_t>>& range_value,
                                      const size_t& min_size = DIM_DEFAULT_SIZE) {
  if (range_value.size() < AXIS_BOTTOM) {
    FE_LOGW("rangeValue is empty!");
    return FAILED;
  }
  if (original_dim_vec.empty()) {
    FE_LOGW("Original dim vector is empty!");
    return FAILED;
  }
  if (original_dim_vec.size() < min_size) {
    /* Both operands are size_t, so both need %zu (min_size was printed with %u before). */
    FE_LOGW("Original dim vector size: %zu is less than %zu!", original_dim_vec.size(), min_size);
    return FAILED;
  }
  if (original_dim_vec.size() != original_range_vec.size()) {
    FE_LOGW("Size of shape is different from size of range!");
    return FAILED;
  }
  if (c0 == 0) {
    FE_LOGE("c0 is zero!");
    return FAILED;
  }
  return SUCCESS;
}

/* Looks up the converter registered for `format` in get_range_axis_value_func_map and applies it,
 * filling range_value from original_range_vec / dim_vec with the given c0.
 * Returns FAILED if no converter is registered for the format; if one is found, the result of
 * FE_CHECK_NOTNULL or of the converter itself is returned. */
Status RangeAxisUtil::GetRangeAxisValueByOriginFormat(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                                      const ge::Format& format, const vector<int64_t>& dim_vec,
                                                      const uint32_t& c0,
                                                      vector<std::pair<int64_t, int64_t>>& range_value) {
  const auto iter_range_get_axis_func = get_range_axis_value_func_map.find(format);
  if (iter_range_get_axis_func == get_range_axis_value_func_map.end()) {
    /* Convert explicitly so the argument type matches the %u specifier. */
    FE_LOGW("Can not get range axis value of old format %u!", static_cast<uint32_t>(format));
    return FAILED;
  }
  /* Bind by reference: copying the shared_ptr would cost an atomic ref-count inc/dec per call. */
  const GetRangeAxisValueInfoByFormatPtr& get_range_axis_func = iter_range_get_axis_func->second;
  FE_CHECK_NOTNULL(get_range_axis_func);
  return (*get_range_axis_func)(original_range_vec, dim_vec, c0, range_value);
}

/* Reports whether a range converter is registered for `format` in
 * get_range_axis_value_func_map. A warning is logged when none is found. */
bool RangeAxisUtil::HasAxisValueFunc(const ge::Format& format) {
  const bool is_registered = (get_range_axis_value_func_map.count(format) != 0);
  if (!is_registered) {
    FE_LOGW("Can not get range axis value of format %u!", format);
  }
  return is_registered;
}

/* Fills range_value for an ND original format.
 * AXIS_C0 is always set to [c0, c0]. Only when both the shape and the range have exactly
 * NCHW_DIMENSION_NUM dimensions are they also read as NCHW: N/C/H/W are copied,
 * C1 = DivisionCeiling(C, c0) for each bound, and Co = [c0, c0]. For any other rank the
 * remaining slots of range_value are left untouched.
 * Returns FAILED if range_value has fewer than AXIS_BOTTOM slots or original_dim_vec is empty.
 * NOTE(review): unlike the other converters, c0 == 0 is not rejected here before it is used as
 * the DivisionCeiling divisor; confirm DivisionCeiling tolerates a zero divisor. */
Status RangeAxisUtil::GetRangeAxisValueByND(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                            const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                            vector<std::pair<int64_t, int64_t>>& range_value) {
  if (range_value.size() < AXIS_BOTTOM) {
    FE_LOGW("rangeValue is empty!");
    return FAILED;
  }
  if (original_dim_vec.empty()) {
    FE_LOGW("Original dim vector is empty!");
    return FAILED;
  }
  /* To differentiate the input datatype of int8 and others */
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  range_value[AXIS_C0] = c0_range;

  /* size_t values must be printed with %zu (they were printed with %d before). */
  FE_LOGD("Size of original_range_vec is %zu, original_dim_vec is %zu.",
          original_range_vec.size(), original_dim_vec.size());
  /* Check original_range_vec size as well, to avoid out-of-bounds access. */
  if ((original_dim_vec.size() == NCHW_DIMENSION_NUM) && (original_range_vec.size() == NCHW_DIMENSION_NUM)) {
    range_value[AXIS_N] = original_range_vec[NCHW_DIM_N];
    range_value[AXIS_C] = original_range_vec[NCHW_DIM_C];
    range_value[AXIS_H] = original_range_vec[NCHW_DIM_H];
    range_value[AXIS_W] = original_range_vec[NCHW_DIM_W];
    const int64_t c0_value = static_cast<int64_t>(c0);
    int64_t c1_first_range = DivisionCeiling(original_range_vec[NCHW_DIM_C].first, c0_value);
    int64_t c1_second_range = DivisionCeiling(original_range_vec[NCHW_DIM_C].second, c0_value);

    range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
    range_value[AXIS_Co] = c0_range;
  }
  return SUCCESS;
}

/* Fills range_value from an NCHW original range vector.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots. This means C0 is set even when validation fails, e.g. for ND or 2-D NCHW
 * to NZ.
 * On success: N/C/H/W are copied from their NCHW positions, C1 = DivisionCeiling(C, c0) for
 * each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByNCHW(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                              const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                              vector<std::pair<int64_t, int64_t>>& range_value) {
  /* C0 Must be set for case ND or 2D-NCHW to NZ */
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[NCHW_DIM_N];
  range_value[AXIS_C] = original_range_vec[NCHW_DIM_C];
  range_value[AXIS_H] = original_range_vec[NCHW_DIM_H];
  range_value[AXIS_W] = original_range_vec[NCHW_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[NCHW_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[NCHW_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  return SUCCESS;
}

/* Fills range_value from an NHWC original range vector.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots. This means C0 is set even when validation fails, e.g. for ND or 2-D NHWC
 * to NZ.
 * On success: N/C/H/W are copied from their NHWC positions, C1 = DivisionCeiling(C, c0) for
 * each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByNHWC(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                              const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                              vector<std::pair<int64_t, int64_t>>& range_value) {
  /* C0 Must be set for case ND or 2D-NHWC to NZ */
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[NHWC_DIM_N];
  range_value[AXIS_C] = original_range_vec[NHWC_DIM_C];
  range_value[AXIS_H] = original_range_vec[NHWC_DIM_H];
  range_value[AXIS_W] = original_range_vec[NHWC_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[NHWC_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[NHWC_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  return SUCCESS;
}

/* Fills range_value from an NC1HWC0 original range vector.
 * - Input with DIM_DEFAULT_SIZE + 1 dims: C1 and C0 are taken directly from the range vector,
 *   and C is rebuilt as C1 * C0 for each bound. Each product is first passed through
 *   FE_INT64_MULCHECK (presumably an int64 overflow guard defined elsewhere).
 * - Any other rank accepted by CheckParamValue: the input is read with NCHW indices,
 *   C1 = DivisionCeiling(C, c0) for each bound, and C0 = [c0, c0].
 * In both cases N/H/W are then read from the NCHW_DIM_N/H/W positions.
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByNC1HWC0(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                                 const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                                 vector<std::pair<int64_t, int64_t>>& range_value) {
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }
  const bool has_explicit_c1_c0 = (original_dim_vec.size() == DIM_DEFAULT_SIZE + 1);
  if (!has_explicit_c1_c0) {
    /* Derive C1/C0 from the C axis and the caller-supplied c0. */
    const int64_t divisor = static_cast<int64_t>(c0);
    const std::pair<int64_t, int64_t>& c_range = original_range_vec[NCHW_DIM_C];
    range_value[AXIS_C1] = std::pair<int64_t, int64_t>(DivisionCeiling(c_range.first, divisor),
                                                       DivisionCeiling(c_range.second, divisor));
    range_value[AXIS_C0] = std::pair<int64_t, int64_t>(divisor, divisor);
    range_value[AXIS_C] = c_range;
  } else {
    /* Keep this order: C1/C0 are stored before the FE_INT64_MULCHECK checks, so they are
     * already written if that macro exits early. */
    range_value[AXIS_C1] = original_range_vec[NC1HWC0_DIM_C1];
    range_value[AXIS_C0] = original_range_vec[NC1HWC0_DIM_C0];
    FE_INT64_MULCHECK(range_value[AXIS_C1].first, range_value[AXIS_C0].first);
    FE_INT64_MULCHECK(range_value[AXIS_C1].second, range_value[AXIS_C0].second);
    range_value[AXIS_C] = std::pair<int64_t, int64_t>(range_value[AXIS_C1].first * range_value[AXIS_C0].first,
                                                      range_value[AXIS_C1].second * range_value[AXIS_C0].second);
  }

  range_value[AXIS_N] = original_range_vec[NCHW_DIM_N];
  range_value[AXIS_H] = original_range_vec[NCHW_DIM_H];
  range_value[AXIS_W] = original_range_vec[NCHW_DIM_W];
  return SUCCESS;
}

/* !!!!Deprecated!!!! FRACTAL_Z ranges are currently interpreted as NCHW, although the actual
 * layout is {HWC/16, N, 16, 16}.
 * On success: N/C/H/W are copied from their NCHW positions, C1 = DivisionCeiling(C, c0) for
 * each bound, and C0 = [c0, c0]. Unlike the NCHW converter, Co is not written, and
 * range_value is left untouched when validation fails.
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByFz(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                            const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                            vector<std::pair<int64_t, int64_t>>& range_value) {
  const bool params_ok = (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) == SUCCESS);
  if (!params_ok) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }
  const int64_t divisor = static_cast<int64_t>(c0);
  const std::pair<int64_t, int64_t>& c_range = original_range_vec[NCHW_DIM_C];

  range_value[AXIS_N] = original_range_vec[NCHW_DIM_N];
  range_value[AXIS_C] = c_range;
  range_value[AXIS_H] = original_range_vec[NCHW_DIM_H];
  range_value[AXIS_W] = original_range_vec[NCHW_DIM_W];
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(DivisionCeiling(c_range.first, divisor),
                                                     DivisionCeiling(c_range.second, divisor));
  range_value[AXIS_C0] = std::pair<int64_t, int64_t>(divisor, divisor);
  return SUCCESS;
}

/* Fills range_value from an HWCN original range vector.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots. This means C0 is set even when validation fails, e.g. for ND or 2-D HWCN
 * to NZ.
 * On success: N/C/H/W are copied from their HWCN positions, C1 = DivisionCeiling(C, c0) for
 * each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByHWCN(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                              const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                              vector<std::pair<int64_t, int64_t>>& range_value) {
  /* C0 Must be set for case ND or 2D-HWCN to NZ */
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[HWCN_DIM_N];
  range_value[AXIS_C] = original_range_vec[HWCN_DIM_C];
  range_value[AXIS_H] = original_range_vec[HWCN_DIM_H];
  range_value[AXIS_W] = original_range_vec[HWCN_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[HWCN_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[HWCN_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  return SUCCESS;
}

/* Fills range_value from a CHWN original range vector.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots. This means C0 is set even when validation fails, e.g. for ND or 2-D CHWN
 * to NZ.
 * On success: N/C/H/W are copied from their CHWN positions, C1 = DivisionCeiling(C, c0) for
 * each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByCHWN(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                              const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                              vector<std::pair<int64_t, int64_t>>& range_value) {
  /* C0 Must be set for case ND or 2D-CHWN to NZ */
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[CHWN_DIM_N];
  range_value[AXIS_C] = original_range_vec[CHWN_DIM_C];
  range_value[AXIS_H] = original_range_vec[CHWN_DIM_H];
  range_value[AXIS_W] = original_range_vec[CHWN_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[CHWN_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[CHWN_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  return SUCCESS;
}

/* Fills range_value from an NDHWC original range vector; at least DIMENSION_NUM_FIVE dims are
 * required.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots, so C0 is set even when validation fails.
 * On success: N/C/H/W/D are copied from their NDHWC positions, C1 = DivisionCeiling(C, c0)
 * for each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByNDHWC(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                               const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                               vector<std::pair<int64_t, int64_t>>& range_value) {
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value, DIMENSION_NUM_FIVE) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[NDHWC_DIM_N];
  range_value[AXIS_C] = original_range_vec[NDHWC_DIM_C];
  range_value[AXIS_H] = original_range_vec[NDHWC_DIM_H];
  range_value[AXIS_W] = original_range_vec[NDHWC_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[NDHWC_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[NDHWC_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  range_value[AXIS_D] = original_range_vec[NDHWC_DIM_D];
  return SUCCESS;
}

/* Fills range_value from an NCDHW original range vector; at least DIMENSION_NUM_FIVE dims are
 * required.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots, so C0 is set even when validation fails.
 * On success: N/C/H/W/D are copied from their NCDHW positions, C1 = DivisionCeiling(C, c0)
 * for each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByNCDHW(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                               const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                               vector<std::pair<int64_t, int64_t>>& range_value) {
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value, DIMENSION_NUM_FIVE) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[NCDHW_DIM_N];
  range_value[AXIS_C] = original_range_vec[NCDHW_DIM_C];
  range_value[AXIS_H] = original_range_vec[NCDHW_DIM_H];
  range_value[AXIS_W] = original_range_vec[NCDHW_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[NCDHW_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[NCDHW_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  range_value[AXIS_D] = original_range_vec[NCDHW_DIM_D];
  return SUCCESS;
}

/* Fills range_value from a DHWCN original range vector; at least DIMENSION_NUM_FIVE dims are
 * required.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots, so C0 is set even when validation fails.
 * On success: N/C/H/W/D are copied from their DHWCN positions, C1 = DivisionCeiling(C, c0)
 * for each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByDHWCN(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                               const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                               vector<std::pair<int64_t, int64_t>>& range_value) {
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value, DIMENSION_NUM_FIVE) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[DHWCN_DIM_N];
  range_value[AXIS_C] = original_range_vec[DHWCN_DIM_C];
  range_value[AXIS_H] = original_range_vec[DHWCN_DIM_H];
  range_value[AXIS_W] = original_range_vec[DHWCN_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[DHWCN_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[DHWCN_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  range_value[AXIS_D] = original_range_vec[DHWCN_DIM_D];
  return SUCCESS;
}

/* Fills range_value from a DHWNC original range vector; at least DIMENSION_NUM_FIVE dims are
 * required.
 * AXIS_C0 = [c0, c0] is written before validation, provided range_value has at least
 * AXIS_BOTTOM slots, so C0 is set even when validation fails.
 * On success: N/C/H/W/D are copied from their DHWNC positions, C1 = DivisionCeiling(C, c0)
 * for each bound, and Co = [c0, c0].
 * Returns FAILED (after logging) if CheckParamValue rejects the inputs. */
Status RangeAxisUtil::GetRangeAxisValueByDHWNC(const vector<std::pair<int64_t, int64_t>>& original_range_vec,
                                               const vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                               vector<std::pair<int64_t, int64_t>>& range_value) {
  std::pair<int64_t, int64_t> c0_range(c0, c0);
  /* CheckParamValue validates range_value's size, but it runs after this write, so the size
   * must be checked here to avoid indexing an undersized vector out of bounds. */
  if (range_value.size() >= AXIS_BOTTOM) {
    range_value[AXIS_C0] = c0_range;
  }
  if (CheckParamValue(original_range_vec, original_dim_vec, c0, range_value, DIMENSION_NUM_FIVE) != SUCCESS) {
    FE_LOGW("Parameter is invalid!");
    return FAILED;
  }

  range_value[AXIS_N] = original_range_vec[DHWNC_DIM_N];
  range_value[AXIS_C] = original_range_vec[DHWNC_DIM_C];
  range_value[AXIS_H] = original_range_vec[DHWNC_DIM_H];
  range_value[AXIS_W] = original_range_vec[DHWNC_DIM_W];
  const int64_t c0_value = static_cast<int64_t>(c0);
  int64_t c1_first_range = DivisionCeiling(original_range_vec[DHWNC_DIM_C].first, c0_value);
  int64_t c1_second_range = DivisionCeiling(original_range_vec[DHWNC_DIM_C].second, c0_value);
  range_value[AXIS_C1] = std::pair<int64_t, int64_t>(c1_first_range, c1_second_range);
  range_value[AXIS_Co] = c0_range;
  range_value[AXIS_D] = original_range_vec[DHWNC_DIM_D];
  return SUCCESS;
}
};  // namespace fe
